[Transform] Attention/Cache transforms #436
Conversation
This looks good, though I have a number of questions and minor suggestions.
If the goal is to use this generally for kv_cache and attn quantize, can we move the initialize_hooked_attention and initialize_hooked_kv_cache to initialize.py?
I understand we haven't hooked them in yet for those workflows but I think these belong there.
Do a pass through for any missing docstrings; otherwise LGTM.
Nice work.
Following for the most part. A few clarifications, but this makes sense to me
Last nightly worked, but e2e failed due to model storage issues.
We can resolve the global var thread. I have another new comment we might want to consider in a follow-up, but I'm marking this as approved. Cool stuff! Excited to see it in action.
Just some questions. Otherwise, LGTM
For the sake of completeness, do you mind adding your kv_cache and attn quantized sample models to this PR description?
impressive work!
## Purpose ##
* Support fully-expressive attention and kv cache quantization
* Support running kv cache quantization evals with hf transformers
* Resolves #1949
* Resolves #1928

```python3
recipe = QuantizationModifier(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"],
            input_activations=QuantizationArgs(
                num_bits=8, type="float", strategy="tensor"
            ),
        )
    }
)
```

```json
{
  "quantization_config": {
    "config_groups": {
      "group_0": {
        "format": null,
        "input_activations": {
          "dynamic": false,
          "num_bits": 8,
          "observer": "minmax",
          "strategy": "tensor",
          "symmetric": true,
          "type": "float"
        },
        "output_activations": null,
        "targets": ["LlamaAttention"],
        "weights": null
      }
    },
    "format": "dense",
    "ignore": [],
    "kv_cache_scheme": {
      "dynamic": false,
      "group_size": null,
      "num_bits": 8,
      "observer": "minmax",
      "strategy": "tensor",
      "symmetric": true,
      "type": "float"
    },
    "quant_method": "compressed-tensors",
    "quantization_status": "frozen"
  }
}
```

## Prerequisites ##
* Must be merged at the same time as vllm-project/compressed-tensors#436

## Changes ##
* Replace hooks
  * Remove `calibrate_kv_cache_input_hook`, `calibrate_kv_cache_output_hook`, `initialize_quantized_kv_cache`
  * Add `calibrate_query_hook`, `calibrate_key_hook`, `calibrate_value_hook`
* QuantizationMixin now initializes "q", "k", and "v" observers ([depending on the attached submodules](https://github.com/vllm-project/llm-compressor/pull/1651/files#diff-33303ae48e185b2fbb14dc45c2052805837deb5723248367b9579321c4c4e974R263-R270)) and adds the appropriate hooks
* Miscellaneous
  * Fix minor shape bug in `_flatten_attention`
  * Add support for "attn_head" strategy in `_flatten_attention`
* Tests
  * Removed old QuantizationKVCache tests (these classes are now tested [here](https://github.com/neuralmagic/compressed-tensors/pull/436/files#diff-6e33ff48047dc4f7c9d969293f87e32e4d5ec3f3e8b741ea757780c8c0aab775))
  * Updated scale names to avoid using enum
  * Avoid unnecessary tokenization to reduce runtime

## Testing ##
* Kv cache regression tests pass
* Able to quantize attention with scripts (will add to examples once loadable in vllm)
  * kylesayrs/Llama-3.2-1B-Instruct-attention-fp8-head
  * kylesayrs/Llama-3.2-1B-Instruct-attention-nvfp4-head
* Nightly passes (in progress)

## Evaluation ##
<details><summary>eval.py</summary>

```python
import sys
import lm_eval

model_id = sys.argv[1]
print(model_id)

results = lm_eval.simple_evaluate(
    # 3) hf serialized
    model="hf",
    model_args={
        "pretrained": model_id,
        "add_bos_token": False,
        "dtype": "auto",
        "device_map": "cuda",
        #"max_length": 128000,
    },
    device="cuda",
    # 3/)
    #tasks=["gsm8k_platinum", "mmlu_llama", "longbench2_single"],
    tasks=["gsm8k_platinum"],
    batch_size=64,
    apply_chat_template=True,
    fewshot_as_multiturn=True,
)

print(model_id)
print(lm_eval.utils.make_table(results))
```

</details>

<details><summary>compress.py</summary>

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs

# Select model and load it.
#model_id = "Qwen/Qwen2.5-14B-Instruct-1M"
model_id = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
args = QuantizationArgs(
    num_bits=8,
    type="float",
    strategy="attn_head",
    symmetric=True,
    observer="static_minmax",
)
recipe = QuantizationModifier(
    # config_groups={
    #     "attention": QuantizationScheme(
    #         #targets=["Qwen2Attention"],
    #         targets=["LlamaAttention"],
    #         input_activations=args,
    #     )
    # }
    kv_cache_scheme=args,
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=DATASET_ID,
    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + f"-KV-FP8-{args.strategy}-{args.observer}"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

</details>

| Model | GSM8K |
| -- | -- |
| nm-testing/Llama-3.1-8B-Instruct | 0.8337 |
| nm-testing/Llama-3.1-8B-Instruct-KV-FP8-Tensor | 0.8271 |
| nm-testing/Llama-3.1-8B-Instruct-KV-FP8-Head | 0.8354 |
| nm-testing/Llama-3.1-8B-Instruct-QKV-FP8-Tensor | 0.8321 |
| nm-testing/Llama-3.1-8B-Instruct-QKV-FP8-Head | 0.8238 |
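For intuition on the `strategy` values used above ("tensor" in the serialized config, "attn_head" in compress.py), here is a rough sketch of how a symmetric minmax observer turns key/value activations into FP8 scales. The tensor layout and the function name are assumptions for illustration; the e4m3 bound of 448 and the per-tensor vs. per-head reduction are the point, not the library's internals.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def minmax_fp8_scale(x: torch.Tensor, strategy: str) -> torch.Tensor:
    # x: key or value states, assumed shape (batch, num_kv_heads, seq_len, head_dim).
    # A static observer would accumulate the running max over calibration
    # batches before dividing; a single batch is shown here for brevity.
    if strategy == "tensor":
        amax = x.abs().amax()               # one scale for the whole tensor
    elif strategy == "attn_head":
        amax = x.abs().amax(dim=(0, 2, 3))  # one scale per kv head
    else:
        raise ValueError(f"unsupported strategy: {strategy}")
    return amax / FP8_MAX

k = torch.randn(1, 8, 128, 64)
print(minmax_fp8_scale(k, "tensor").shape)     # torch.Size([])
print(minmax_fp8_scale(k, "attn_head").shape)  # torch.Size([8])
```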
## Purpose ##

## Prerequisites ##

## Changes ##

### New Classes ###
* `QuantizedAttentionImpl` injects itself into the model by registering a new attention implementation called `ct_hooked_attention` and overriding `model.config._attn_implementation` to be the new implementation name (see the sketch after this list)
* `QuantizedKVCache` injects itself into the model by overriding the `past_key_values` input kwarg to attention, and wrapping the functionality of the original cache
* `register_query_hook`, `register_key_hook`, `register_value_hook`
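The registration mechanism can be pictured with the public `AttentionInterface` registry exposed by recent transformers releases. This is a minimal sketch under that assumption, not the actual `QuantizedAttentionImpl` code; the wrapper body and the use of `sdpa_attention_forward` as the delegated kernel are illustrative stand-ins.

```python
# Sketch only: assumes transformers >= 4.48 (AttentionInterface); this is not
# the compressed-tensors implementation, just an illustration of the mechanism.
from transformers import AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward

HOOKED_ATTENTION_NAME = "ct_hooked_attention"

def hooked_attention_forward(module, query, key, value, *args, **kwargs):
    # Hypothetical hook point: calibration/quantization hooks attached to
    # `module` could observe or fake-quantize query/key/value here before
    # delegating to an existing kernel (sdpa used as a stand-in).
    return sdpa_attention_forward(module, query, key, value, *args, **kwargs)

# Register the new implementation, then point the model at it by name,
# which is what overriding `model.config._attn_implementation` amounts to.
AttentionInterface.register(HOOKED_ATTENTION_NAME, hooked_attention_forward)
# model.config._attn_implementation = HOOKED_ATTENTION_NAME
```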
### Quantization Lifecycle Changes ###
* `initialize_hooked_kv_cache` and `initialize_hooked_attention` (the latter only if attention modules are explicitly targeted, see `is_narrow_match`) are called from `initialize_module_for_quantization`
* `QuantizationConfig.from_pretrained` was cleaned up with additional comments
* The `kv_cache_scheme` field is added if there are any attention modules with a `quantization_scheme` attached (a quick check follows this list)
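As a generic sanity check on the `kv_cache_scheme` emission (not part of this PR), the serialized config can be inspected after saving. The path below follows the `SAVE_DIR` naming from the compress script earlier and is only an example.

```python
import json
from pathlib import Path

# Example path following the SAVE_DIR pattern from the compress script above.
save_dir = Path("Llama-3.1-8B-Instruct-KV-FP8-attn_head-static_minmax")
config = json.loads((save_dir / "config.json").read_text())

# Expect a kv_cache_scheme entry when attention modules carried a quantization_scheme.
print(config["quantization_config"]["kv_cache_scheme"])
```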
### Helpers ###
* `is_narrow_match` is used to check that attention modules are being specifically targeted (rather than targeting all modules in a layer)
* `get_num_attn_heads`, `get_num_kv_heads`, `get_head_dim` get attention config values from the model config (a sketch follows this list)
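The helper names suggest plain reads of the HF model config; the sketch below shows what such helpers typically compute. The attribute names are standard transformers config fields, but the actual helpers in this PR may cover more model families.

```python
from transformers import PretrainedConfig

def get_num_attn_heads(config: PretrainedConfig) -> int:
    return config.num_attention_heads

def get_num_kv_heads(config: PretrainedConfig) -> int:
    # GQA models expose num_key_value_heads; fall back to MHA otherwise.
    return getattr(config, "num_key_value_heads", config.num_attention_heads)

def get_head_dim(config: PretrainedConfig) -> int:
    # Some configs carry head_dim explicitly; otherwise derive it.
    return getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
```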
## Testing ##
* `is_narrow_match`

## Evaluation ##
* `eval.py`
* `compress.py`